// Copyright (c) 2020-present, INSPUR Co, Ltd. All rights reserved.
// This source code is licensed under Apache 2.0 License.

#ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS
#endif

#ifdef GFLAGS
#ifdef NUMA
#include <numa.h>
#include <numaif.h>
#endif
#ifndef OS_WIN
#include <unistd.h>
#endif
#include <fcntl.h>
#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>

#include "port/port.h"
#include "port/stack_trace.h"
#include "util/gflags_compat.h"
#include "util/testutil.h"
#include "util/random.h"
#include "abstract_tree.h"
#include "pure_mem/key_index/art/rowex_tree.h"
#include "pure_mem/key_index/art/rowex_tree_fullkey.h"
#include "pure_mem/rangearena/thread_safe_sorted_list.h"


#ifdef OS_WIN
#include <io.h>  // open/close
#endif

using GFLAGS_NAMESPACE::ParseCommandLineFlags;
using GFLAGS_NAMESPACE::RegisterFlagValidator;
using GFLAGS_NAMESPACE::SetUsageMessage;


// Default Env, used throughout for timing (NowMicros) and thread spawning.
static rocksdb::Env* FLAGS_env = rocksdb::Env::Default();

// Comma-separated benchmark list; each entry is dispatched in Benchmark::Run.
DEFINE_string(
    tbenchmarks,
    "fillseq,"
    "fillrandom,"
    "fillurandom,"
    "readseq,"
    "readrandom,"
    "readurandom,"
    "multireadrandom,",

    "Comma-separated list of operations to run in the specified"
    " order. Available benchmarks:\n"
    "\tfillseq       -- write N values in sequential key"
    " order in async mode\n");

// Selects which tree implementation Benchmark::Run instantiates.
DEFINE_string(
    treetype,
    "art,"
    "art_als,",

    "Comma-separated list of operations to run in the specified"
    " order. Available benchmarks:\n"
    "\tart       -- adaptive radix tree with rowex.\n"
    "\tart_als   -- art with node build by adaptive length struct tech. \n"
    "\tbtree     -- b tree.\n");

// Number of key/value pairs each fill benchmark inserts.
DEFINE_int64(tnum, 1000000, "Number of key/values to place in database");


DEFINE_int64(treads, -1, "Number of read operations to do.  "
             "If negative, do FLAGS_kvnum reads.");

DEFINE_int64(tseed, 0, "Seed base for random number generators. "
             "When 0 it is deterministic.");

DEFINE_int32(tthreads, 1, "Number of concurrent threads to run.");

DEFINE_int32(tduration, 0, "Time in seconds for the random-ops tests to run."
             " When 0 then num & reads determine the test duration");

// Key layout knobs consumed by Benchmark::GenerateKeyFromInt.
DEFINE_int32(tprefix_size, 0, "prefix_size");
DEFINE_int32(tkey_size, 16, "size of each key");
DEFINE_int32(tvalue_size, 100, "Size of each value");


// gflags validator hook for --tkey_size. Currently a no-op that accepts
// every value; it exists so RegisterFlagValidator has something to bind.
static bool ValidateKeySize(const char* /*flagname*/, int32_t /*value*/) {
  return true;
}

// Skew factor for random reads; consumed by Benchmark::GetRandomKey.
DEFINE_double(tread_random_exp_range, 0.0,
              "Read random's key will be generated using distribution of "
              "num * exp(-r) where r is uniform number from 0 to this value. "
              "The larger the number is, the more skewed the reads are. "
              "Only used in readrandom and multireadrandom benchmarks.");

DEFINE_int64(twrites, -1, "Number of write operations to do. If negative, do"
             " --num reads.");


DEFINE_int32(treadwritepercent, 90, "Ratio of reads to reads/writes (expressed"
             " as percentage) for the ReadRandomWriteRandom workload. The "
             "default value 90 means 90% operations out of all reads and writes"
             " operations are reads. In other words, 9 gets for every 1 put.");

DEFINE_int32(tperf_level, rocksdb::PerfLevel::kDisable, "Level of perf collection");

// Parameters of the sine wave used by SineRate()/AddNoise() for
// rate-modulated workloads: f(x) = A sin(Bx + C) + D.
DEFINE_double(tsine_a, 1,
             "A in f(x) = A sin(bx + c) + d");

DEFINE_double(tsine_b, 1,
             "B in f(x) = A sin(bx + c) + d");

DEFINE_double(tsine_c, 0,
             "C in f(x) = A sin(bx + c) + d");

DEFINE_double(tsine_d, 1,
             "D in f(x) = A sin(bx + c) + d");

// Registers the (no-op) validator; the dummy bool only forces the call to
// happen during static initialization.
static const bool FLAGS_key_size_dummy __attribute__((__unused__)) =
    RegisterFlagValidator(&FLAGS_tkey_size, &ValidateKeySize);

namespace rocksdb{

struct TreeLeafNode{
  Slice key;
  Slice value;
  TreeLeafNode(Slice& k, Slice& v):key(k),value(v){}
};


// Helper for quickly generating random data.
// Helper for quickly generating random data.
class RandomGenerator {
 private:
  std::string data_;   // pre-built compressible bytes, read cyclically
  unsigned int pos_;   // offset of the next slice handed out

 public:
  RandomGenerator() {
    // We use a limited amount of data over and over again and ensure
    // that large enough to serve all typical value sizes we want to write.
    Random rnd(301);
    std::string piece;
    while (data_.size() < (unsigned)std::max(1048576, FLAGS_tvalue_size)) {
      test::CompressibleString(&rnd, 0.5, 100, &piece);
      data_.append(piece);
    }
    pos_ = 0;
  }

  // Returns a slice of `len` bytes from the shared buffer, wrapping to the
  // start when fewer than `len` bytes remain. The slice aliases data_ and
  // stays valid only as long as this generator lives.
  Slice Generate(unsigned int len) {
    assert(len <= data_.size());
    if (pos_ + len > data_.size()) {
      pos_ = 0;
    }
    pos_ += len;
    return Slice(data_.data() + pos_ - len, len);
  }

  // TTL variant kept for API parity with db_bench; the byte stream is
  // identical, so delegate instead of duplicating Generate's body.
  Slice GenerateWithTTL(unsigned int len) { return Generate(len); }
};

// Appends `msg` to `*str`, inserting a single space separator when both
// sides are non-empty; empty messages leave `*str` untouched.
static void AppendWithSpace(std::string* str, Slice msg) {
  if (!msg.empty()) {
    if (!str->empty()) {
      str->push_back(' ');
    }
    str->append(msg.data(), msg.size());
  }
}


// Operation categories used as keys into Stats::hist_ so per-operation
// latency histograms can be tracked and merged independently.
enum OperationType : unsigned char {
  kRead = 0,
  kWrite,
  kDelete,
  kSeek,
  kMerge,
  kUpdate,
  kCompress,
  kUncompress,
  kCrc,
  kHash,
  kOthers
};


// Per-thread benchmark statistics: operation counts, byte totals, wall-clock
// timestamps, and optional per-operation histograms. Individual threads'
// Stats are combined via Merge() and printed via Report().
class Stats {
 private:
  int id_;                       // thread id, -1 before Start(id) is called
  uint64_t start_;               // NowMicros() when Start() ran
  uint64_t sine_interval_;
  uint64_t finish_;              // NowMicros() when Stop() ran
  double seconds_;               // elapsed seconds (finish_ - start_)
  uint64_t done_;                // operations completed
  uint64_t last_report_done_;
  uint64_t next_report_;
  uint64_t bytes_;               // total payload bytes processed
  uint64_t last_op_finish_;
  uint64_t last_report_finish_;
  std::unordered_map<OperationType, std::shared_ptr<HistogramImpl>,
                     std::hash<unsigned char>> hist_;
  std::string message_;
  bool exclude_from_merge_;
  friend class CombinedStats;

 public:
  Stats() { Start(-1); }

  // (Re)initializes every counter and timestamp for thread `id`.
  void Start(int id) {
    id_ = id;
    hist_.clear();
    done_ = 0;
    last_report_done_ = 0;
    next_report_ = 0;
    bytes_ = 0;
    seconds_ = 0;
    start_ = FLAGS_env->NowMicros();
    sine_interval_ = FLAGS_env->NowMicros();
    finish_ = start_;
    // Fix: assign last_op_finish_ AFTER start_ has been refreshed. The
    // original read start_ first, so the ctor's Start(-1) copied an
    // uninitialized value and later restarts copied the previous run's time.
    last_op_finish_ = start_;
    last_report_finish_ = start_;
    message_.clear();
    // When set, stats from this thread won't be merged with others.
    exclude_from_merge_ = false;
  }

  // Folds another thread's stats into this one: sums counters, merges
  // histograms, and widens the [start_, finish_] window.
  void Merge(const Stats& other) {
    if (other.exclude_from_merge_)
      return;

    for (auto it = other.hist_.begin(); it != other.hist_.end(); ++it) {
      auto this_it = hist_.find(it->first);
      if (this_it != hist_.end()) {
        this_it->second->Merge(*(other.hist_.at(it->first)));
      } else {
        hist_.insert({ it->first, it->second });
      }
    }

    done_ += other.done_;
    bytes_ += other.bytes_;
    seconds_ += other.seconds_;
    if (other.start_ < start_) start_ = other.start_;
    if (other.finish_ > finish_) finish_ = other.finish_;

    // Just keep the messages from one thread
    if (message_.empty()) message_ = other.message_;
  }

  // Stamps the finish time and derives elapsed seconds.
  void Stop() {
    finish_ = FLAGS_env->NowMicros();
    seconds_ = (finish_ - start_) * 1e-6;
  }

  // Appends a free-form note that Report() prints after the rates.
  void AddMessage(Slice msg) {
    AppendWithSpace(&message_, msg);
  }

  void SetId(int id) { id_ = id; }
  void SetExcludeFromMerge() { exclude_from_merge_ = true; }

  void ResetSineInterval() {
    sine_interval_ = FLAGS_env->NowMicros();
  }

  uint64_t GetSineInterval() {
    return sine_interval_;
  }

  uint64_t GetStart() {
    return start_;
  }

  void AddBytes(int64_t n) {
    bytes_ += n;
  }

  // Prints "<name> : micros/op ops/sec [MB/s] [message]" to stdout.
  void Report(const Slice& name) {
    // Pretend at least one op was done in case we are running a benchmark
    // that does not call FinishedOps().
    if (done_ < 1) done_ = FLAGS_tnum;

    std::string extra;
    if (bytes_ > 0) {
      // Rate is computed on actual elapsed time, not the sum of per-thread
      // elapsed times.
      double elapsed = (finish_ - start_) * 1e-6;
      char rate[100];
      snprintf(rate, sizeof(rate), "%6.1f MB/s",
               (bytes_ / 1048576.0) / elapsed);
      extra = rate;
    }
    AppendWithSpace(&extra, message_);
    double elapsed = (finish_ - start_) * 1e-6;
    double throughput = (double)done_/elapsed;

    fprintf(stdout, "%-12s : %11.3f micros/op %ld ops/sec;%s%s\n",
            name.ToString().c_str(),
            seconds_ * 1e6 / done_,
            (long)throughput,
            (extra.empty() ? "" : " "),
            extra.c_str());

    fflush(stdout);
  }
};


// Aggregates the throughput of repeated benchmark runs and reports the
// average and median ops/sec (and MB/sec when byte counts were recorded).
class CombinedStats {
 public:
  // Records one run's throughput, derived from the Stats' op/byte counts
  // and its [start_, finish_] wall-clock window.
  void AddStats(const Stats& stat) {
    uint64_t total_ops = stat.done_;
    uint64_t total_bytes = stat.bytes_;  // local, not a member: no trailing _
    double elapsed;

    if (total_ops < 1) {
      total_ops = 1;  // avoid reporting 0 ops/sec for no-op runs
    }

    elapsed = (stat.finish_ - stat.start_) * 1e-6;
    throughput_ops_.emplace_back(total_ops / elapsed);

    if (total_bytes > 0) {
      double mbs = (total_bytes / 1048576.0);
      throughput_mbs_.emplace_back(mbs / elapsed);
    }
  }

  // Prints average and median throughput over all recorded runs. The MB/sec
  // columns appear only when every run recorded a byte count.
  void Report(const std::string& bench_name) {
    const char* name = bench_name.c_str();
    int num_runs = static_cast<int>(throughput_ops_.size());
    if (num_runs == 0) {
      return;  // nothing recorded; avoid dividing by zero below
    }

    if (throughput_mbs_.size() == throughput_ops_.size()) {
      fprintf(stdout,
              "%s [AVG    %d runs] : %d ops/sec; %6.1f MB/sec\n"
              "%s [MEDIAN %d runs] : %d ops/sec; %6.1f MB/sec\n",
              name, num_runs, static_cast<int>(CalcAvg(throughput_ops_)),
              CalcAvg(throughput_mbs_), name, num_runs,
              static_cast<int>(CalcMedian(throughput_ops_)),
              CalcMedian(throughput_mbs_));
    } else {
      fprintf(stdout,
              "%s [AVG    %d runs] : %d ops/sec\n"
              "%s [MEDIAN %d runs] : %d ops/sec\n",
              name, num_runs, static_cast<int>(CalcAvg(throughput_ops_)), name,
              num_runs, static_cast<int>(CalcMedian(throughput_ops_)));
    }
  }

 private:
  // Arithmetic mean. Takes const& — the original copied the vector for no
  // reason; only CalcMedian needs a mutable copy (for sorting).
  double CalcAvg(const std::vector<double>& data) {
    double avg = 0;
    for (double x : data) {
      avg += x;
    }
    avg = avg / data.size();
    return avg;
  }

  // Median; deliberately takes its argument by value so sorting the copy
  // leaves the caller's vector untouched.
  double CalcMedian(std::vector<double> data) {
    assert(data.size() > 0);
    std::sort(data.begin(), data.end());

    size_t mid = data.size() / 2;
    if (data.size() % 2 == 1) {
      // Odd number of entries
      return data[mid];
    } else {
      // Even number of entries
      return (data[mid] + data[mid - 1]) / 2;
    }
  }

  std::vector<double> throughput_ops_;  // ops/sec, one entry per run
  std::vector<double> throughput_mbs_;  // MB/sec, one entry per run w/ bytes
};

// State shared by all concurrent executions of the same benchmark.
struct SharedState {
  port::Mutex mu;
  port::CondVar cv;
  int total;
  int perf_level;
  std::shared_ptr<RateLimiter> write_rate_limiter;
  std::shared_ptr<RateLimiter> read_rate_limiter;

  // Each thread goes through the following states:
  //    (1) initializing
  //    (2) waiting for others to be initialized
  //    (3) running
  //    (4) done

  long num_initialized;
  long num_done;
  bool start;

  SharedState() : cv(&mu), perf_level(FLAGS_tperf_level) { }
};

// Per-thread state for concurrent executions of the same benchmark.
struct ThreadState {
  int tid;             // 0..n-1 when running in n threads
  Random64 rand;         // Has different seeds for different threads
  Stats stats;
  SharedState* shared;
  ITree* tree_;

  /* implicit */ ThreadState(int index)
      : tid(index),
        rand((FLAGS_tseed ? FLAGS_tseed : 1000) + index) {
  }
};



void loadKeyFromTreeLeafNode(void *node, Slice &key) {
  key = ((TreeLeafNode*)node)->key;
}

// Tracks progress toward a fixed operation budget. Also records the start
// timestamp, though nothing reads it back at the moment.
class Duration {
 public:
  Duration(int64_t max_ops)
      : max_ops_(max_ops), ops_(0), start_at_(FLAGS_env->NowMicros()) {}

  // Consumes `increment` operations from the budget — at least one, so a
  // caller passing 0 (or a negative count) cannot spin forever — and
  // returns true once the budget has been exceeded.
  bool Done(int64_t increment) {
    ops_ += (increment > 0) ? increment : 1;
    return ops_ > max_ops_;
  }

 private:
  int64_t max_ops_;    // total operations allowed
  int64_t ops_;        // operations consumed so far
  uint64_t start_at_;  // NowMicros() at construction; currently unused
};

// Drives the whole benchmark: instantiates the tree selected by --treetype,
// parses the --tbenchmarks list, and for each entry spawns --tthreads worker
// threads (via RunBenchmark/ThreadBody) that execute one of the Write*/Read*
// member functions below, then merges and prints their Stats.
class Benchmark {
 private:
  const SliceTransform* prefix_extractor_;
  int64_t num_;                  // ops per benchmark (from --tnum)
  int value_size_;               // from --tvalue_size
  int key_size_;                 // from --tkey_size
  int prefix_size_;              // from --tprefix_size
  int64_t keys_per_prefix_ = 0;  // 0 => keys have no prefix section
  int64_t reads_;
  double read_random_exp_range_;
  int64_t writes_;
  int64_t readwrites_;
  std::vector<std::string> keys_;  // optional pre-generated key list

  // Tree under test; created in Run(). NOTE(review): never deleted.
  ITree *tree_;

  // Keys inserted by DoWrite, consumed later by ReadAllWrites.
  AtomicLinkedList<Slice> write_keys_;

 public:
  // Snapshots all flag values; negative --treads/--twrites fall back to
  // --tnum. Exits immediately if the prefix cannot fit inside the key.
  Benchmark():
        prefix_extractor_(NewFixedPrefixTransform(FLAGS_tprefix_size)),
        num_(FLAGS_tnum),
        value_size_(FLAGS_tvalue_size),
        key_size_(FLAGS_tkey_size),
        prefix_size_(FLAGS_tprefix_size),
        reads_(FLAGS_treads < 0 ? FLAGS_tnum : FLAGS_treads),
        read_random_exp_range_(0.0),
        writes_(FLAGS_twrites < 0 ? FLAGS_tnum : FLAGS_twrites),
        readwrites_(
            (FLAGS_twrites < 0 && FLAGS_treads < 0)
                ? FLAGS_tnum
                : ((FLAGS_twrites > FLAGS_treads) ? FLAGS_twrites : FLAGS_treads))
  {
    if (FLAGS_tprefix_size > FLAGS_tkey_size) {
      fprintf(stderr, "prefix size is larger than key size");
      exit(1);
    }
  }

  ~Benchmark() {}

  // Returns a Slice over a freshly heap-allocated key buffer.
  // NOTE(review): the buffer is never freed by any caller in this file —
  // apparently an accepted leak for a benchmark process; confirm.
  Slice AllocateKey() {
    char* data = new char[key_size_];
    return Slice(data, key_size_);
  }

  // Generate key according to the given specification and random number.
  // The resulting key will have the following format (if keys_per_prefix_
  // is positive), extra trailing bytes are either cut off or padded with '0'.
  // The prefix value is derived from key value.
  //   ----------------------------
  //   | prefix 00000 | key 00000 |
  //   ----------------------------
  // If keys_per_prefix_ is 0, the key is simply a binary representation of
  // random number followed by trailing '0's
  //   ----------------------------
  //   |        key 00000         |
  //   ----------------------------
  void GenerateKeyFromInt(uint64_t v, int64_t num_keys, Slice* key) {
    // Pre-generated key list takes precedence when present.
    if (!keys_.empty()) {
      assert(keys_.size() == static_cast<size_t>(num_keys));
      assert(v < static_cast<uint64_t>(num_keys));
      *key = keys_[v];
      return;
    }
    char* start = const_cast<char*>(key->data());
    char* pos = start;
    if (keys_per_prefix_ > 0) {
      int64_t num_prefix = num_keys / keys_per_prefix_;
      int64_t prefix = v % num_prefix;
      int bytes_to_fill = std::min(prefix_size_, 8);
      // Write the prefix big-endian so lexicographic key order matches
      // numeric order on little-endian hosts.
      if (port::kLittleEndian) {
        for (int i = 0; i < bytes_to_fill; ++i) {
          pos[i] = (prefix >> ((bytes_to_fill - i - 1) << 3)) & 0xFF;
        }
      } else {
        memcpy(pos, static_cast<void*>(&prefix), bytes_to_fill);
      }
      if (prefix_size_ > 8) {
        // fill the rest with 0s
        memset(pos + 8, '0', prefix_size_ - 8);
      }
      pos += prefix_size_;
    }

    // Same big-endian encoding for the key body, then zero-pad the tail.
    int bytes_to_fill = std::min(key_size_ - static_cast<int>(pos - start), 8);
    if (port::kLittleEndian) {
      for (int i = 0; i < bytes_to_fill; ++i) {
        pos[i] = (v >> ((bytes_to_fill - i - 1) << 3)) & 0xFF;
      }
    } else {
      memcpy(pos, static_cast<void*>(&v), bytes_to_fill);
    }
    pos += bytes_to_fill;
    if (key_size_ > pos - start) {
      memset(pos, 0, key_size_ - (pos - start));
    }
  }

  // Top-level driver: builds the tree, then runs each comma-separated entry
  // of --tbenchmarks, honoring [Xn] repeat and [Wn] warmup suffixes.
  void Run() {
    // NOTE(review): FLAGS_treetype defaults to "art,art_als," which matches
    // neither compare() below — the defaults would hit assert(false);
    // presumably a single tree name is always passed. Confirm.
    Slice treeType(FLAGS_treetype.data(), FLAGS_treetype.size());
    if (treeType.compare("art") == 0){
      tree_ = new RowexTree(loadKeyFromTreeLeafNode);
    }else if (treeType.compare("art_als") == 0){
      tree_ = new RowexTreeFullKey(loadKeyFromTreeLeafNode, nullptr);
    }else{
      assert(false);
    }

    std::stringstream benchmark_stream(FLAGS_tbenchmarks);
    std::string name;
    while (std::getline(benchmark_stream, name, ',')) {
      // Sanitize parameters
      num_ = FLAGS_tnum;
      reads_ = (FLAGS_treads < 0 ? FLAGS_tnum : FLAGS_treads);
      writes_ = (FLAGS_twrites < 0 ? FLAGS_tnum : FLAGS_twrites);
      value_size_ = FLAGS_tvalue_size;
      key_size_ = FLAGS_tkey_size;
      read_random_exp_range_ = FLAGS_tread_random_exp_range;

      void (Benchmark::*method)(ThreadState*) = nullptr;
      void (Benchmark::*post_process_method)() = nullptr;

      int num_threads = FLAGS_tthreads;

      // Parse an optional "[X<n>-W<m>]"-style suffix: X repeats the
      // benchmark n times, W adds m unreported warmup runs.
      int num_repeat = 1;
      int num_warmup = 0;
      if (!name.empty() && *name.rbegin() == ']') {
        auto it = name.find('[');
        if (it == std::string::npos) {
          fprintf(stderr, "unknown benchmark arguments '%s'\n", name.c_str());
          exit(1);
        }
        std::string args = name.substr(it + 1);
        args.resize(args.size() - 1);
        name.resize(it);

        std::string bench_arg;
        std::stringstream args_stream(args);
        while (std::getline(args_stream, bench_arg, '-')) {
          if (bench_arg.empty()) {
            continue;
          }
          if (bench_arg[0] == 'X') {
            // Repeat the benchmark n times
            std::string num_str = bench_arg.substr(1);
            num_repeat = std::stoi(num_str);
          } else if (bench_arg[0] == 'W') {
            // Warm up the benchmark for n times
            std::string num_str = bench_arg.substr(1);
            num_warmup = std::stoi(num_str);
          }
        }
      }

      // Map the benchmark name to a member function.
      if (name == "fillseq") {
        method = &Benchmark::WriteSeq;
      } else if (name == "fillrandom") {
        method = &Benchmark::WriteRandom;
      } else if (name == "fillurandom") {
        method = &Benchmark::WriteUniqueRandom;
      } else if (name == "fillalrandom") {
        method = &Benchmark::WriteAdapLengthRandom;
      }else if (name == "readrandom") {
        method = &Benchmark::ReadRandom;
      } else if (name == "readallwrites") {
        method = &Benchmark::ReadAllWrites;
      } else if (name == "readseq") {
        method = &Benchmark::ReadSeq;
      } else if (!name.empty()) {  // No error message for empty name
        fprintf(stderr, "unknown benchmark '%s'\n", name.c_str());
        exit(1);
      }

      if (method != nullptr) {
        if (num_warmup > 0) {
          printf("Warming up benchmark by running %d times\n", num_warmup);
        }

        // Warmup runs are executed but their stats are discarded.
        for (int i = 0; i < num_warmup; i++) {
          RunBenchmark(num_threads, name, method);
        }

        if (num_repeat > 1) {
          printf("Running benchmark for %d times\n", num_repeat);
        }

        CombinedStats combined_stats;
        for (int i = 0; i < num_repeat; i++) {
          Stats stats = RunBenchmark(num_threads, name, method);
          combined_stats.AddStats(stats);
        }
        if (num_repeat > 1) {
          combined_stats.Report(name);
        }
      }
      if (post_process_method != nullptr) {
        (this->*post_process_method)();
      }
    }
  }

 private:
  // Per-thread argument bundle passed to ThreadBody.
  struct ThreadArg {
    Benchmark* bm;
    SharedState* shared;
    ThreadState* thread;
    void (Benchmark::*method)(ThreadState*);
  };

  // Thread entry point: rendezvous with the other workers on the shared
  // condition variable, run the selected benchmark method, then signal done.
  static void ThreadBody(void* v) {
    ThreadArg* arg = reinterpret_cast<ThreadArg*>(v);
    SharedState* shared = arg->shared;
    ThreadState* thread = arg->thread;
    {
      MutexLock l(&shared->mu);
      shared->num_initialized++;
      if (shared->num_initialized >= shared->total) {
        shared->cv.SignalAll();
      }
      // Block until the coordinator flips shared->start.
      while (!shared->start) {
        shared->cv.Wait();
      }
    }

    SetPerfLevel(static_cast<PerfLevel> (shared->perf_level));
    perf_context.EnablePerLevelPerfContext();
    thread->stats.Start(thread->tid);
    (arg->bm->*(arg->method))(thread);
    thread->stats.Stop();

    {
      MutexLock l(&shared->mu);
      shared->num_done++;
      if (shared->num_done >= shared->total) {
        shared->cv.SignalAll();
      }
    }
  }

  // Spawns n threads all running `method`, waits for them to initialize,
  // releases them simultaneously, then merges and reports their stats.
  Stats RunBenchmark(int n, Slice name,
                     void (Benchmark::*method)(ThreadState*)) {
    SharedState shared;
    shared.total = n;
    shared.num_initialized = 0;
    shared.num_done = 0;
    shared.start = false;

    ThreadArg* arg = new ThreadArg[n];

    for (int i = 0; i < n; i++) {
      arg[i].bm = this;
      arg[i].method = method;
      arg[i].shared = &shared;
      arg[i].thread = new ThreadState(i);
      arg[i].thread->shared = &shared;
      // NOTE(review): arg[i].thread->tree_ is never assigned here, yet
      // ReadRandomWriteRandom dereferences it; that benchmark is also not
      // reachable from Run()'s dispatch table. Confirm before wiring it up.
      FLAGS_env->StartThread(ThreadBody, &arg[i]);
    }

    // Wait for all workers to check in, then start them together.
    shared.mu.Lock();
    while (shared.num_initialized < n) {
      shared.cv.Wait();
    }

    shared.start = true;
    shared.cv.SignalAll();
    while (shared.num_done < n) {
      shared.cv.Wait();
    }
    shared.mu.Unlock();

    // Stats for some threads can be excluded.
    Stats merge_stats;
    for (int i = 0; i < n; i++) {
      merge_stats.Merge(arg[i].thread->stats);
    }
    merge_stats.Report(name);

    for (int i = 0; i < n; i++) {
      delete arg[i].thread;
    }
    delete[] arg;

    return merge_stats;
  }

  // Key-generation strategies for DoWrite/KeyGenerator.
  enum WriteMode {
    RANDOM, SEQUENTIAL, UNIQUE_RANDOM, A_LENGTH_U_RANDOM
  };


  // Thin dispatchers from benchmark names to DoWrite modes.
  void WriteSeq(ThreadState* thread) {
    DoWrite(thread, SEQUENTIAL);
  }

  void WriteRandom(ThreadState* thread) {
    DoWrite(thread, RANDOM);
  }

  void WriteUniqueRandom(ThreadState* thread) {
    DoWrite(thread, UNIQUE_RANDOM);
  }
  void WriteAdapLengthRandom(ThreadState* thread) {
    DoWrite(thread, A_LENGTH_U_RANDOM);
  }

  // Produces the key sequence for a given WriteMode:
  //  - SEQUENTIAL: 0,1,2,...            (Next)
  //  - RANDOM: uniform over [0, num)    (Next)
  //  - UNIQUE_RANDOM: a pre-shuffled permutation of [0, num) (Next)
  //  - A_LENGTH_U_RANDOM: pre-built random-length byte-string keys (NextKey)
  class KeyGenerator {
   public:
    KeyGenerator(Random64* rand, WriteMode mode, uint64_t num,
                 uint64_t /*num_per_set*/ = 64 * 1024)
        : rand_(rand), mode_(mode), num_(num), next_(0) {
      if (mode_ == UNIQUE_RANDOM) {
        values_.resize(num_);
        for (uint64_t i = 0; i < num_; ++i) {
          values_[i] = i;
        }
        std::shuffle(
            values_.begin(), values_.end(),
            std::default_random_engine(static_cast<unsigned int>(FLAGS_tseed)));
      }else if(mode == A_LENGTH_U_RANDOM){
        keys_.resize(num_);
        // Key lengths are uniform in [minLength, 2*minLength).
        uint32_t minLength = 20;

        for (size_t i = 0; i < num_; i++)
        {
          uint32_t curLength = rand_->Next() % minLength + minLength;
          // NOTE(review): these buffers are never freed (Slice does not own).
          uint8_t* data = new uint8_t[curLength];
          for(uint32_t j = 0; j< curLength; j++){
            data[j] = d(generator);
          }
          keys_[i].data_ = (char*)data;
          keys_[i].size_ = curLength;
        }
      }
    }

    // Next integer key for the integer-based modes; asserts in
    // A_LENGTH_U_RANDOM mode (use NextKey there instead).
    uint64_t Next() {
      switch (mode_) {
        case SEQUENTIAL:
          return next_++;
        case RANDOM:
          return rand_->Next() % num_;
        case UNIQUE_RANDOM:
          assert(next_ < num_);
          return values_[next_++];
        case A_LENGTH_U_RANDOM:
          assert(false);
          return 0;
      }
      assert(false);
      return std::numeric_limits<uint64_t>::max();
    }

    // Next pre-built byte-string key; only valid in A_LENGTH_U_RANDOM mode.
    Slice NextKey(){
      if (mode_ != A_LENGTH_U_RANDOM){
        assert(false);
      }
      assert(next_ < num_);
      return keys_[next_++];
    }

   private:
    Random64* rand_;
    WriteMode mode_;
    const uint64_t num_;
    uint64_t next_;                  // cursor into values_/keys_
    std::vector<uint64_t> values_;
    std::vector<Slice> keys_;
    std::default_random_engine generator;
    // NOTE(review): std::uniform_int_distribution<uint8_t> is undefined
    // behavior per the standard (IntType must be short or wider); works on
    // common implementations but should be <unsigned short> narrowed.
    std::uniform_int_distribution<uint8_t> d;
  };

  // f(x) = A sin(Bx + C) + D with the --tsine_* flag values.
  double SineRate(double x) {
    return FLAGS_tsine_a*sin((FLAGS_tsine_b*x) + FLAGS_tsine_c) + FLAGS_tsine_d;
  }

  // Inserts `writes_` (or num_) key/value pairs into tree_ using the given
  // key-generation mode, recording every key in write_keys_ for later
  // verification and counting duplicate-key rejections.
  void DoWrite(ThreadState* thread, WriteMode write_mode) {
    const int test_duration = write_mode == RANDOM ? FLAGS_tduration : 0;
    const int64_t num_ops = writes_ == 0 ? num_ : writes_;

    KeyGenerator key_gens(&(thread->rand), write_mode, num_);
    Duration duration(num_ops);
    if (num_ != FLAGS_tnum) {
      char msg[100];
      snprintf(msg, sizeof(msg), "(%" PRIu64 " ops)", num_);
      thread->stats.AddMessage(msg);
    }

    RandomGenerator gen;
    int64_t bytes = 0;

    int64_t num_written = 0;
    int64_t num_repeat = 0;  // keys rejected because they already existed

    while (!duration.Done(1)) {
      Slice key;
      if (A_LENGTH_U_RANDOM == write_mode){
        key = key_gens.NextKey();
      }else{
      key = AllocateKey();
      GenerateKeyFromInt(key_gens.Next(), FLAGS_tnum, &key);
      }
      write_keys_.insertHead(key);

      // NOTE(review): leaf nodes are heap-allocated and never freed; the
      // tree presumably takes logical (not memory) ownership. Confirm.
      Slice value = gen.Generate(value_size_);
      TreeLeafNode* kv = new TreeLeafNode(key, value);
      void *retKV=nullptr;
      bool ok = tree_->insertNoReplace(key, kv, retKV);
      if (ok){
      bytes += value_size_ + key_size_;
      ++num_written;
      }else{
        ++num_repeat;
      }
    }

    std::cout << "DoWrite: num_written:"<<num_written << ", num_repeat:"<< num_repeat<<std::endl;
    thread->stats.AddBytes(bytes);
  }

  // Draws a key index in [0, FLAGS_tnum): uniform when
  // read_random_exp_range_ is 0, otherwise exponentially skewed and then
  // scrambled with a multiplicative hash to break locality.
  int64_t GetRandomKey(Random64* rand) {
    uint64_t rand_int = rand->Next();
    int64_t key_rand;
    if (read_random_exp_range_ == 0) {
      key_rand = rand_int % FLAGS_tnum;
    } else {
      const uint64_t kBigInt = static_cast<uint64_t>(1U) << 62;
      long double order = -static_cast<long double>(rand_int % kBigInt) /
                          static_cast<long double>(kBigInt) *
                          read_random_exp_range_;
      long double exp_ran = std::exp(order);
      uint64_t rand_num =
          static_cast<int64_t>(exp_ran * static_cast<long double>(FLAGS_tnum));
      // Map to a different number to avoid locality.
      const uint64_t kBigPrime = 0x5bd1e995;
      // Overflow is like %(2^64). Will have little impact of results.
      key_rand = static_cast<int64_t>((rand_num * kBigPrime) % FLAGS_tnum);
    }
    return key_rand;
  }

  // Performs reads_ lookups with keys produced by the given mode, counting
  // exact-match hits via the tree's greater-or-equal search.
  void ReadWithMode(ThreadState* thread, WriteMode write_mode) {
    int64_t read = 0;
    int64_t found = 0;
    int64_t bytes = 0;
    int num_keys = 0;
    Slice key = AllocateKey();
    PinnableSlice pinnable_val;

    KeyGenerator key_gens(&(thread->rand), write_mode, num_);
    Duration duration(reads_);
    while (!duration.Done(1)) {
      GenerateKeyFromInt(key_gens.Next(), FLAGS_tnum, &key);
      read++;
      // getGE returns the smallest entry >= key; count it only on an
      // exact key match.
      void * kv = tree_->getGE(key);
      if (kv != nullptr){
        Slice retKey = ((TreeLeafNode*)kv)->key;
        if (retKey.compare(key) == 0){
          found++;
          bytes += key.size() + ((TreeLeafNode*)kv)->value.size();
        }
      }
    }

    char msg[100];
    snprintf(msg, sizeof(msg), "(%" PRIu64 " of %" PRIu64 " found)\n",
             found, read);

    thread->stats.AddBytes(bytes);
    thread->stats.AddMessage(msg);

    if (FLAGS_tperf_level > rocksdb::PerfLevel::kDisable) {
      thread->stats.AddMessage(std::string("PERF_CONTEXT:\n") +
                               get_perf_context()->ToString());
    }
  }


  void ReadSeq(ThreadState* thread){
   ReadWithMode(thread, SEQUENTIAL);
  }

  // Drains write_keys_ and verifies every previously written key is found
  // in the tree exactly; a mismatch on a present key asserts.
  void ReadAllWrites(ThreadState* thread){
    int64_t read = 0;
    int64_t found = 0;
    int64_t bytes = 0;

    Slice cur_key;
    while(!write_keys_.empty()){
      write_keys_.sweepHead(cur_key);
      read++;
      void * kv = tree_->getGE(cur_key);
      if (kv != nullptr){
        Slice retKey = ((TreeLeafNode*)kv)->key;
        if (retKey.compare(cur_key) == 0){
          found++;
          bytes += retKey.size() + ((TreeLeafNode*)kv)->value.size();
        }else{
          // A key we inserted must be retrievable exactly.
          assert(false);
        }
      }
    }

    char msg[100];
    snprintf(msg, sizeof(msg), "(%" PRIu64 " of %" PRIu64 " found)\n",
             found, read);

    thread->stats.AddBytes(bytes);
    thread->stats.AddMessage(msg);

    if (FLAGS_tperf_level > rocksdb::PerfLevel::kDisable) {
      thread->stats.AddMessage(std::string("PERF_CONTEXT:\n") +
                               get_perf_context()->ToString());
    }
  }

  // reads_ random lookups using GetRandomKey's (optionally skewed)
  // distribution; counts exact-match hits.
  void ReadRandom(ThreadState* thread) {
    int64_t read = 0;
    int64_t found = 0;
    int64_t bytes = 0;
    int num_keys = 0;
    Slice key = AllocateKey();
    PinnableSlice pinnable_val;

    Duration duration(reads_);
    while (!duration.Done(1)) {
      // We use same key_rand as seed for key and column family so that we can
      // deterministically find the cfh corresponding to a particular key, as it
      // is done in DoWrite method.
      int64_t key_rand = GetRandomKey(&thread->rand);
      GenerateKeyFromInt(key_rand, FLAGS_tnum, &key);

      read++;
      void * kv = tree_->getGE(key);
      if (kv != nullptr){
        Slice retKey = ((TreeLeafNode*)kv)->key;
        if (retKey.compare(key) == 0){
          found++;
          bytes += key.size() + ((TreeLeafNode*)kv)->value.size();
        }
      }
    }

    char msg[100];
    snprintf(msg, sizeof(msg), "(%" PRIu64 " of %" PRIu64 " found)\n",
             found, read);

    thread->stats.AddBytes(bytes);
    thread->stats.AddMessage(msg);

    if (FLAGS_tperf_level > rocksdb::PerfLevel::kDisable) {
      thread->stats.AddMessage(std::string("PERF_CONTEXT:\n") +
                               get_perf_context()->ToString());
    }
  }

  // Add the noice to the QPS
  double AddNoise(double origin, double noise_ratio) {
    if (noise_ratio < 0.0 || noise_ratio > 1.0) {
      return origin;
    }
    int band_int = static_cast<int>(FLAGS_tsine_a);
    double delta = (rand() % band_int - band_int / 2) * noise_ratio;
    if (origin + delta < 0) {
      return origin;
    } else {
      return (origin + delta);
    }
  }

  // This is different from ReadWhileWriting because it does not use
  // an extra thread.
  // NOTE(review): this method reads thread->tree_, which nothing in this
  // file assigns, and it is not reachable from Run()'s dispatch table —
  // confirm whether it is dead code or wired up elsewhere.
  void ReadRandomWriteRandom(ThreadState* thread) {
    RandomGenerator gen;
    int64_t found = 0;
    int get_weight = 0;
    int put_weight = 0;
    int64_t reads_done = 0;
    int64_t writes_done = 0;
    Duration duration(readwrites_);

    Slice key = AllocateKey();

    // the number of iterations is the larger of read_ or write_
    while (!duration.Done(1)) {
      GenerateKeyFromInt(thread->rand.Next() % FLAGS_tnum, FLAGS_tnum, &key);
      if (get_weight == 0 && put_weight == 0) {
        // one batch completed, reinitialize for next batch
        get_weight = FLAGS_treadwritepercent;
        put_weight = 100 - get_weight;
      }
      if (get_weight > 0) {
        // do all the gets first
        void *retKV = thread->tree_->getGE(key);
        if (retKV != nullptr && ((TreeLeafNode*)retKV)->key.compare(key) == 0){
          found++;
        }

        get_weight--;
        reads_done++;
      } else  if (put_weight > 0) {
        // then do all the corresponding number of puts
        // for all the gets we have done earlier
        Slice value = gen.Generate(value_size_);
        TreeLeafNode *kv = new TreeLeafNode(key, value);
        void* retv = nullptr;
        // NOTE(review): side-effecting call inside assert() — the insert
        // disappears entirely under NDEBUG builds.
        assert(thread->tree_->insertNoReplace(key, kv, retv));
        put_weight--;
        writes_done++;
      }
    }
    char msg[100];
    snprintf(msg, sizeof(msg), "( reads:%" PRIu64 " writes:%" PRIu64 \
             " total:%" PRIu64 " found:%" PRIu64 ")",
             reads_done, writes_done, readwrites_, found);
    thread->stats.AddMessage(msg);
  }
};

// Tool entry point: installs the stack-trace handler, registers the gflags
// usage message (once per process), parses the command line, and runs every
// benchmark listed in --tbenchmarks. Always returns 0.
int tree_bench_tool(int argc, char** argv) {
  rocksdb::port::InstallStackTraceHandler();

  static bool usage_registered = false;
  if (!usage_registered) {
    usage_registered = true;
    SetUsageMessage(std::string("\nUSAGE:\n") + std::string(argv[0]) +
                    " [OPTIONS]...");
  }
  ParseCommandLineFlags(&argc, &argv, true);

  rocksdb::Benchmark benchmark;
  benchmark.Run();
  return 0;
}
}  // namespace rocksdb
#endif

#include <cstdio>
// Program entry: delegates to the benchmark tool when built with gflags,
// otherwise reports that gflags is required. Using #ifdef/#else removes the
// unreachable fprintf that previously sat after the return in GFLAGS builds
// (behavior is identical either way), and fixes the mangled directive
// indentation.
int main(int argc, char** argv) {
#ifdef GFLAGS
  return rocksdb::tree_bench_tool(argc, argv);
#else
  fprintf(stderr, "Please install gflags to run rocksdb tools\n");
  return 1;
#endif
}
